{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Automatic Differentiation In Swift\n",
    "\n",
    "This notebook builds up the concepts of automatic differentiation in Swift from the constituent pieces."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: A Trivial Example"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a warm up, we will start with a trivial example $x^2$. The derivative $\\frac{d}{dx} x^2$ is $2x$. We can represent this as follows in code."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func square(_ x: Float) -> Float {\n",
    "    return x * x\n",
    "}\n",
    "\n",
    "func square_derivative(_ x: Float) -> Float {\n",
    "    return 2 * x\n",
    "}"
   ]
  },
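  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One quick way to sanity-check a handwritten derivative like the one above is to compare it against a central finite difference. The following is only a minimal sketch: `numericalDerivative` and its step size `h` are hypothetical names introduced for this check, not part of any library.\n",
    "\n",
    "```swift\n",
    "func square(_ x: Float) -> Float {\n",
    "    return x * x\n",
    "}\n",
    "\n",
    "func square_derivative(_ x: Float) -> Float {\n",
    "    return 2 * x\n",
    "}\n",
    "\n",
    "// Hypothetical helper: central-difference approximation of f'(x).\n",
    "func numericalDerivative(_ f: (Float) -> Float, at x: Float, h: Float = 1e-2) -> Float {\n",
    "    return (f(x + h) - f(x - h)) / (2 * h)\n",
    "}\n",
    "\n",
    "let analytic = square_derivative(3)\n",
    "let numeric = numericalDerivative(square, at: 3)\n",
    "print(analytic, numeric)  // The two values should agree closely.\n",
    "```\n",
    "\n",
    "The step size is a trade-off: a smaller `h` reduces the approximation error but amplifies `Float` rounding error, which is one reason to prefer exact derivatives over numerical ones."
   ]
  },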
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aside: Recall the Chain Rule"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we discussed before, the chain rule tells us how to differentiate composite functions, and is written: $$\\frac{d}{dx}\\left[f\\left(g(x)\\right)\\right] = f'\\left(g(x)\\right)g'(x)$$"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: A slightly more complicated example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simple polynomials are the easy case. Let's take the derivative of a more complicated function: $\\sin(x^2)$.\n",
    "\n",
    "The derivative of this expression $\\frac{d}{dx}\\sin(x^2)$ (recall the chain rule) is: $\\cos(x^2) \\cdot 2x$.\n",
    "\n",
    "In code, this is as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import Glibc\n",
    "\n",
    "func exampleFunction(_ x: Float) -> Float {\n",
    "    return sin(x * x)\n",
    "}\n",
    "\n",
    "func exampleFunctionDerivative(_ x: Float) -> Float {\n",
    "    return 2 * x * cos(x * x)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: A more efficient implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Looking at the chain rule and our derivative implementation above, we notice that there's redundant computation going on. Concretely, in both `exampleFunction` and `exampleFunctionDerivative` we compute `x * x`. (In the chain rule definition, this is $g(x)$.) As a result, we often want to do that computation only once (because in general it can be any expensive computation, and even a multiply can be expensive with large tensors).\n",
    "\n",
    "We can thus rewrite our function and its derivative as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func exampleFunctionDerivativeEfficient(_ x: Float) -> (value: Float, backward: () -> Float) {\n",
    "    let xSquared = x * x\n",
    "    let value = sin(xSquared)\n",
    "    let backward = {2 * x * cos(xSquared)}  // A closure that captures xSquared\n",
    "    return (value: value, backward: backward)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You'll see that we're defining a somewhat more complex *closure* than we've seen before here."
   ]
  },
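  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a usage sketch (repeating the definition above so the snippet stands alone, and assuming `Foundation` is available to provide `sin` and `cos`), a single call yields both the forward value and a closure that produces the derivative on demand:\n",
    "\n",
    "```swift\n",
    "import Foundation\n",
    "\n",
    "func exampleFunctionDerivativeEfficient(_ x: Float) -> (value: Float, backward: () -> Float) {\n",
    "    let xSquared = x * x\n",
    "    let value = sin(xSquared)\n",
    "    let backward = { 2 * x * cos(xSquared) }  // Reuses xSquared from the forward pass\n",
    "    return (value: value, backward: backward)\n",
    "}\n",
    "\n",
    "let result = exampleFunctionDerivativeEfficient(0)\n",
    "print(result.value)       // sin(0 * 0)\n",
    "print(result.backward())  // 2 * 0 * cos(0 * 0)\n",
    "```\n",
    "\n",
    "The derivative is only computed if `backward` is actually called, so callers that need just the value pay nothing extra beyond creating the closure."
   ]
  },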
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aside: Fully general derivatives"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've actually been a little sloppy with our mathematics. To be fully correct, $\\frac{d}{dx}x^2 = 2x\\frac{d}{dx}$. This is because in mathematics, $x$ doesn't have to be a specific number, it could be itself another expression, which we'd need to use the chain rule to calculate. In order to represent this correctly in code, we need to change the type signature slightly to multiply by the \"$\\frac{d}{dx}$\", resulting in the following (we're changing the name `backward` to `deriv` here to signify that it's a little different):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func exampleFunctionValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    let xSquared = x * x\n",
    "    let value = sin(xSquared)\n",
    "    let deriv = { (v: Float) -> Float in\n",
    "        let gradXSquared = v * cos(xSquared)\n",
    "        let gradX = gradXSquared * 2 * x\n",
    "        return gradX\n",
    "    }\n",
    "    return (value: value, deriv: deriv)\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 4: Rewrite using `deriv`s"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've chosen to represent the drivatives with a `deriv` closure because this allows us to rewrite the forward pass into a very regular form. Below, we rewrite the handwritten derivative above into a regular form.\n",
    "\n",
    "> Note: be sure to carefully read through the code and convince yourself that this new spelling of the `deriv` results in the exact same computation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func sinValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    return (value: sin(x), deriv: {v in cos(x) * v})\n",
    "}\n",
    "\n",
    "func squareValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    return (value: x * x, deriv: {v in 2 * x * v})\n",
    "}\n",
    "\n",
    "func exampleFunctionWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    let (xSquared, deriv1) = squareValueWithDeriv(x)\n",
    "    let (value, deriv2) = sinValueWithDeriv(xSquared)\n",
    "    return (value: value, deriv: { v in\n",
    "        let gradXSquared = deriv2(v)\n",
    "        let gradX = deriv1(gradXSquared)\n",
    "        return gradX\n",
    "    })\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Aside: Generalizing to arbitrary expressions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Up until this point, we've been handwriting the derivatives for specific functions. But we now have a formulation that is regular and composible. (In fact, it is so regular, we can make the computer write the backwards pass for us! aka automatic differentiation.) The rules are:\n",
    "\n",
    " 1. Rewrite every expression in the forward pass into a form that computes the value like normal, and also produces an additional deriv function.\n",
    " 2. Construct a backwards pass that threads the derivs together in the reverse order."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In an abstract form, we transform a function that looks like:\n",
    "\n",
    "```swift\n",
    "func myFunction(_ arg: Float) -> Float {\n",
    "    let tmp1 = expression1(arg)\n",
    "    let tmp2 = expression2(tmp1)\n",
    "    let tmp3 = expression3(tmp2)\n",
    "    return tmp3\n",
    "}\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "into a function that looks like this:\n",
    "\n",
    "```swift\n",
    "func myFunctionValueWithDeriv(_ arg: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    let (tmp1, deriv1) = expression1ValueWithDeriv(arg)\n",
    "    let (tmp2, deriv2) = expression2ValueWithDeriv(tmp1)\n",
    "    let (tmp3, deriv3) = expression3ValueWithDeriv(tmp2)\n",
    "    return (value: tmp3,\n",
    "            deriv: { v in\n",
    "                let grad2 = deriv3(v)\n",
    "                let grad1 = deriv2(grad2)\n",
    "                let gradArg = deriv1(grad1)\n",
    "                return gradArg\n",
    "    })\n",
    "}\n",
    "```"
   ]
  },
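  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because the transformed form is so regular, the two rules can be captured in a small combinator. The following is a minimal sketch for unary `Float` functions only; the `ValueWithDeriv` typealias and the `compose` helper are hypothetical names introduced here for illustration, and this is not how the compiler implements the transformation.\n",
    "\n",
    "```swift\n",
    "import Foundation\n",
    "\n",
    "// Hypothetical alias for the function shape used throughout this notebook.\n",
    "typealias ValueWithDeriv = (Float) -> (value: Float, deriv: (Float) -> Float)\n",
    "\n",
    "func sinValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    return (value: sin(x), deriv: { v in cos(x) * v })\n",
    "}\n",
    "\n",
    "func squareValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    return (value: x * x, deriv: { v in 2 * x * v })\n",
    "}\n",
    "\n",
    "// Hypothetical helper: runs `first` then `second` in the forward pass,\n",
    "// and threads their derivs together in the reverse order.\n",
    "func compose(_ first: @escaping ValueWithDeriv,\n",
    "             _ second: @escaping ValueWithDeriv) -> ValueWithDeriv {\n",
    "    return { x in\n",
    "        let (tmp1, deriv1) = first(x)\n",
    "        let (tmp2, deriv2) = second(tmp1)\n",
    "        return (value: tmp2, deriv: { v in deriv1(deriv2(v)) })\n",
    "    }\n",
    "}\n",
    "\n",
    "let sinOfSquare = compose(squareValueWithDeriv, sinValueWithDeriv)\n",
    "let (value, deriv) = sinOfSquare(2)\n",
    "print(value)     // sin(2 * 2)\n",
    "print(deriv(1))  // 2 * 2 * cos(2 * 2)\n",
    "```\n",
    "\n",
    "Since `compose` itself returns a `ValueWithDeriv`, longer chains such as `expression1`, `expression2`, `expression3` can be built by nesting calls to it."
   ]
  },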
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 5: Generalize beyond unary functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Up until now, we have been using functions that don't \"reuse\" values in the forward pass. Our running example of $\\frac{d}{dx}\\sin(x^2)$ is too simple. Let's make it a bit more complicated, and use $\\frac{d}{dx}\\sin(x^2)+x^2$ as our motivating expression going forward. From mathematics, we know that the derivative should be: $$\\frac{d}{dx}\\sin\\left(x^2\\right) + x^2 = \\left(2x\\cos\\left(x^2\\right)+2x\\right)\\frac{d}{dx}$$\n",
    "\n",
    "Let's see how we write the deriv (pay attention to the signature of the deriv for the `+` function)!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func myComplexFunction(_ x: Float) -> Float {\n",
    "    let tmp1 = square(x)\n",
    "    let tmp2 = sin(tmp1)\n",
    "    let tmp3 = tmp2 + tmp1\n",
    "    return tmp3\n",
    "}\n",
    "\n",
    "func plusWithDeriv(_ x: Float, _ y: Float) -> (value: Float, deriv: (Float) -> (Float, Float)) {\n",
    "    return (value: x + y, deriv: {v in (v, v)})  // Value semantics are great! :-)\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "func myComplexFunctionValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    let (tmp1, pb1) = squareValueWithDeriv(x)\n",
    "    let (tmp2, pb2) = sinValueWithDeriv(tmp1)\n",
    "    let (tmp3, pb3) = plusWithDeriv(tmp2, tmp1)\n",
    "    return (value: tmp3,\n",
    "            deriv: { v in\n",
    "        // Initialize the gradients for all values at zero.\n",
    "        var gradX = Float(0.0)\n",
    "        var grad1 = Float(0.0)\n",
    "        var grad2 = Float(0.0)\n",
    "        var grad3 = Float(0.0)\n",
    "        // Add the temporaries to the gradients as we run the backwards pass.\n",
    "        grad3 += v\n",
    "        let (tmp2, tmp1b) = pb3(grad3)\n",
    "        grad2 += tmp2\n",
    "        grad1 += tmp1b\n",
    "        let tmp1a = pb2(grad2)\n",
    "        grad1 += tmp1a\n",
    "        let tmpX = pb1(grad1)\n",
    "        gradX += tmpX\n",
    "        // Return the computed gradients.\n",
    "        return gradX\n",
    "    })\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Helper method\n",
    "func square(_ x: Float) -> Float {\n",
    "    return x * x\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Non-unary functions (e.g. `+`) have a deriv that returns a tuple that corresponds to their arguments. This allows gradients to flow upwards in a pure functional manner.\n",
    "\n",
    "In order to handle the re-use of intermediary values (in this case, the expression $x^2$), we need to introduce 2 additional concepts:\n",
    "\n",
    " 1. **Sum**: We need to sum the derivatives produced by $\\frac{d}{dx}x^2$ to the values produced from $\\frac{d}{dx}\\sin\\left(x^2\\right)$ in order to correctly compute the derivative value of $\\frac{d}{dx}\\left(\\sin\\left(x^2\\right) + x^2\\right)$.\n",
    " 2. **Zero**: As a result, we need to initialize the derivatives for each variable to a value: zero!\n",
    "\n",
    "We now have all the pieces required for automatic differentiation in Swift. Let's see how they come together."
   ]
  },
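  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make **Sum** and **Zero** concrete, here is a compact handwritten sketch of the same derivative, checked against the closed-form expression $2x\\cos(x^2) + 2x$. The name `complexValueWithDeriv` is hypothetical and the function is written by hand for illustration; it is not generated by the transformation rules.\n",
    "\n",
    "```swift\n",
    "import Foundation\n",
    "\n",
    "func complexValueWithDeriv(_ x: Float) -> (value: Float, deriv: (Float) -> Float) {\n",
    "    let xSquared = x * x\n",
    "    let value = sin(xSquared) + xSquared\n",
    "    return (value: value, deriv: { v in\n",
    "        var gradXSquared = Float(0)        // Zero: no contributions yet.\n",
    "        gradXSquared += v                  // Sum: contribution via `+ xSquared`.\n",
    "        gradXSquared += v * cos(xSquared)  // Sum: contribution via `sin(xSquared)`.\n",
    "        return gradXSquared * 2 * x        // Chain through x * x.\n",
    "    })\n",
    "}\n",
    "\n",
    "let x: Float = 1.5\n",
    "let (_, deriv) = complexValueWithDeriv(x)\n",
    "let expected = 2 * x * cos(x * x) + 2 * x\n",
    "print(deriv(1), expected)  // The two values should agree closely.\n",
    "```\n",
    "\n",
    "Dropping either `+=` line would silently lose one of the two paths through which `xSquared` influences the result, which is why every reused value needs an accumulator that starts at zero."
   ]
  },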
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6: Automatic Differentiation in Swift\n",
    "\n",
    "When you annotate a function `@differentiable`, the compiler will take your function and generate a second function that corresponds to the `...ValueWithDeriv` functions we wrote out by hand above using the simple transformation rules.\n",
    "\n",
    "You can access these auto-generated function by calling `valueWithPullback`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@differentiable\n",
    "func myFunction(_ x: Float) -> Float {\n",
    "    return x * x\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.0\r\n",
      "(Float) -> Float\r\n"
     ]
    }
   ],
   "source": [
    "let (value, deriv) = valueWithPullback(at: 3, in: myFunction)\n",
    "print(value)\n",
    "print(type(of: deriv))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 7: gradient\n",
    "\n",
    "Now that we have a deriv, how to we \"kick off\" the deriv computation to actually compute the derivative? We use the constant value `1.0`!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "6.0\n"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "deriv(1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have no re-implemented the `gradient` function."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 8: Generalized Differentiability & Protocols"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "So far, we've been looking at functions operating on scalar (`Float`) values, but you can take derivatives of functions that operate on vectors (aka higher dimensions) too. In order to support this, you need your type to conform to the `Differentiable` protocol, which often involves ensuring your type conforms to the [`AdditiveArithmetic` protocol](https://github.com/apple/swift/blob/0c452616820bfbc4f3197dd418c74adadc830b5c/stdlib/public/core/Integers.swift#L31). The salient bits of that protocol are:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```swift\n",
    "public protocol AdditiveArithmetic : Equatable {\n",
    "  /// The zero value.\n",
    "  ///\n",
    "  /// - Note: Zero is the identity element for addition; for any value,\n",
    "  ///   `x + .zero == x` and `.zero + x == x`.\n",
    "  static var zero: Self { get }\n",
    "  /// Adds two values and produces their sum.\n",
    "  ///\n",
    "  /// - Parameters:\n",
    "  ///   - lhs: The first value to add.\n",
    "  ///   - rhs: The second value to add.\n",
    "  static func +(lhs: Self, rhs: Self) -> Self\n",
    "  \n",
    "  //...\n",
    "}\n",
    "```\n",
    "\n",
 Note: The [`Differentiable`]">
    "> Note: The [`Differentiable`](https://github.com/apple/swift/blob/0c452616820bfbc4f3197dd418c74adadc830b5c/stdlib/public/core/AutoDiff.swift#L102) protocol is slightly more complicated in order to support non-differentiable member variables, such as activation functions."
   ]
  },
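  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of why these two requirements matter, the sketch below conforms a small vector type to `AdditiveArithmetic`. `Vector2` is a hypothetical type introduced only for this example, and conforming to `AdditiveArithmetic` alone does not make a type `Differentiable`; it only supplies the **Zero** and **Sum** operations that gradient accumulation relies on.\n",
    "\n",
    "```swift\n",
    "// Hypothetical example type; not part of any library.\n",
    "struct Vector2: AdditiveArithmetic {\n",
    "    var x: Float\n",
    "    var y: Float\n",
    "\n",
    "    // Zero: the starting value for an accumulated gradient.\n",
    "    static var zero: Vector2 { return Vector2(x: 0, y: 0) }\n",
    "\n",
    "    // Sum: how two gradient contributions are combined.\n",
    "    static func + (lhs: Vector2, rhs: Vector2) -> Vector2 {\n",
    "        return Vector2(x: lhs.x + rhs.x, y: lhs.y + rhs.y)\n",
    "    }\n",
    "\n",
    "    static func - (lhs: Vector2, rhs: Vector2) -> Vector2 {\n",
    "        return Vector2(x: lhs.x - rhs.x, y: lhs.y - rhs.y)\n",
    "    }\n",
    "}\n",
    "\n",
    "// Accumulate contributions exactly as the scalar backwards pass did.\n",
    "var grad = Vector2.zero\n",
    "grad += Vector2(x: 1, y: 2)\n",
    "grad += Vector2(x: 0.5, y: -2)\n",
    "print(grad)\n",
    "```\n",
    "\n",
    "The `+=` operator and the `==` comparison are not written out here: the standard library provides a default `+=` for `AdditiveArithmetic` types, and the compiler synthesizes `Equatable` for a struct whose stored properties are all `Equatable`."
   ]
  },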
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Next up: The `Layer` protocol!"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Swift",
   "language": "swift",
   "name": "swift"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
